{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "import json\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_accuracy_from_files_no_dict(tst_file, predictions_file, num_uniuqe_predictions=5, bias=False):\n",
    "    '''\n",
    "    Tests the accuracy of the predictions made by Pass2Path (saved in a predictions file) against the\n",
    "    input test file, pairing the two files line by line.\n",
    "    Inputs:\n",
    "    tst_file - txt file in which every line is of the form: original_password<tab>target_password\n",
    "    predictions_file - txt file in which every line is of the form: original_password<tab>json_list(prediction,score)\n",
    "    num_uniuqe_predictions - number of unique guesses to consider in the accuracy calculation. Predictions are\n",
    "                            taken in the order in which they appear in the json list. NOTE: they are not\n",
    "                            re-sorted here; the list is presumably already in descending order of score -\n",
    "                            verify against the code that writes the predictions file.\n",
    "    bias - if True, the original password is used as the first guess (it counts towards\n",
    "           num_uniuqe_predictions).\n",
    "    This version does not use a dictionary - line i of both files must hold the same original password.\n",
    "    Evaluation stops (with a printed error) at the first pair of lines that do not match, or when the test\n",
    "    file has fewer lines than the predictions file.\n",
    "    Output: accuracy over the evaluated lines (nan if no line was evaluated)\n",
    "    '''\n",
    "    match_vec = []\n",
    "    # Explicit counter instead of the loop index: stays defined (0) when the predictions file is empty\n",
    "    lines_read = 0\n",
    "    with open(tst_file, 'r') as f_tst, open(predictions_file, 'r') as f_pred:\n",
    "        for pred_line in f_pred:\n",
    "            lines_read += 1\n",
    "            test_line = f_tst.readline()\n",
    "            if not test_line:\n",
    "                # readline() returns '' at EOF: the test file is shorter than the predictions file\n",
    "                print('Error: Test file has fewer lines than the predictions file. Stopping...')\n",
    "                break\n",
    "            orig_1, j_list = pred_line.split('\\t')\n",
    "            orig_2, target = test_line.split('\\t')\n",
    "            target = target.split('\\n')[0]\n",
    "            if orig_1 != orig_2:\n",
    "                print('Error: Lines in file do not match. Stopping...')\n",
    "                break\n",
    "            predictions = [pred[0] for pred in json.loads(j_list)]\n",
    "            # Order-preserving de-duplication; with bias the original password is the first guess\n",
    "            guesses = [orig_1] if bias else []\n",
    "            seen = set(guesses)\n",
    "            for guess in predictions:\n",
    "                if guess not in seen:\n",
    "                    seen.add(guess)\n",
    "                    guesses.append(guess)\n",
    "            match_vec.append(target in guesses[:num_uniuqe_predictions])\n",
    "    if lines_read != len(match_vec):\n",
    "        print('Error: Total tested samples ({}) does not match the number of lines in the files ({})'\n",
    "              .format(len(match_vec), lines_read))\n",
    "    # Report the number of samples that actually went into the mean\n",
    "    print('Accuracy calculated over {} samples'.format(len(match_vec)))\n",
    "    if not match_vec:\n",
    "        # np.mean of an empty array emits a RuntimeWarning; return nan explicitly instead\n",
    "        return np.nan\n",
    "    return np.mean(np.array(match_vec))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_accuracy_from_files(tst_file, predictions_file, num_uniuqe_predictions=5, bias=False):\n",
    "    '''\n",
    "    Tests the accuracy of the predictions made by Pass2Path (saved in a predictions file) against the\n",
    "    input test file, looking the targets up by original password.\n",
    "    Inputs:\n",
    "    tst_file - txt file in which every line is of the form: original_password<tab>target_password\n",
    "    predictions_file - txt file in which every line is of the form: original_password<tab>json_list(prediction,score)\n",
    "    num_uniuqe_predictions - number of unique guesses to consider in the accuracy calculation. Predictions are\n",
    "                            taken in the order in which they appear in the json list. NOTE: they are not\n",
    "                            re-sorted here; the list is presumably already in descending order of score -\n",
    "                            verify against the code that writes the predictions file.\n",
    "    bias - if True, the original password is used as the first guess (it counts towards\n",
    "           num_uniuqe_predictions).\n",
    "    This version uses a dictionary - the number of lines in both files may differ. Prediction lines whose\n",
    "    original password does not appear in the test file are skipped, and the accuracy is calculated over the\n",
    "    remaining prediction lines. If an original password appears more than once in the test file, only its\n",
    "    last target is kept.\n",
    "    Output: accuracy over the evaluated prediction lines (nan if no line was evaluated)\n",
    "    '''\n",
    "    test_dict = {}\n",
    "    with open(tst_file, 'r') as f_tst:\n",
    "        for line in f_tst:\n",
    "            orig, target = line.split('\\t')\n",
    "            test_dict[orig] = target.split('\\n')[0]\n",
    "\n",
    "    match_vec = []\n",
    "    skipped = 0\n",
    "    with open(predictions_file, 'r') as f_pred:\n",
    "        for pred_line in f_pred:\n",
    "            orig, j_list = pred_line.split('\\t')\n",
    "            predictions = [pred[0] for pred in json.loads(j_list)]\n",
    "            target = test_dict.get(orig)\n",
    "            if target is None:\n",
    "                # No ground truth for this original password: the line cannot be scored\n",
    "                skipped += 1\n",
    "                continue\n",
    "            # Order-preserving de-duplication; with bias the original password is the first guess\n",
    "            guesses = [orig] if bias else []\n",
    "            seen = set(guesses)\n",
    "            for guess in predictions:\n",
    "                if guess not in seen:\n",
    "                    seen.add(guess)\n",
    "                    guesses.append(guess)\n",
    "            match_vec.append(target in guesses[:num_uniuqe_predictions])\n",
    "    if skipped:\n",
    "        print('Warning: {} prediction lines have no entry in the test file and were skipped'.format(skipped))\n",
    "    # Report the number of samples that actually went into the mean (not the number of lines read)\n",
    "    print('Accuracy calculated over {} samples'.format(len(match_vec)))\n",
    "    if not match_vec:\n",
    "        # np.mean of an empty array emits a RuntimeWarning; return nan explicitly instead\n",
    "        return np.nan\n",
    "    return np.mean(np.array(match_vec))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy calculated over 100000 samples\n",
      "0.09384\n"
     ]
    }
   ],
   "source": [
    "# Score the saved predictions: up to 10 unique guesses per password, the original password being the first guess\n",
    "tst_file = 'test_full_email_100000.txt'\n",
    "predictions_file = 'pass2path_-1_test_full_email_100000.predictions'\n",
    "acc = test_accuracy_from_files(\n",
    "    tst_file,\n",
    "    predictions_file,\n",
    "    num_uniuqe_predictions=10,\n",
    "    bias=True,\n",
    ")\n",
    "print(acc)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
